#!/usr/bin/env python3

"""
Module for performing all tasks in specdb
"""

import hashlib
from hashlib import sha256
import io
from io import StringIO
import json
import os
import re
import sys
import sqlite3
import tarfile
import time

import pandas as pd
import pynmrstar
from pynmrstar import utils
import ruamel.yaml

SQLITE_PAGE_SIZE_INDEX  = 16
SQLITE_HEADER_LENGTH    = 16
SQLITE_PAGE_COUNT_INDEX = 28


SQL_TYPES = {
	"TEXT" : "",
	"FLOAT" : 0.0,
	"INTEGER": 0,
	"BLOB" : ""
}

env_keys = ['SPECDB_DB', 'SPECDB_BACKUP', 'SPECDB_LOG', 'SPECDB_NAME']

table_order = [
	'user', 'project', 'target', 'construct', 'expression',
	'purification_batch', 'buffer', 'buffer_components', 'pst',
	'batch_components', 'spectrometer', 'probe', 'pulse_sequence']

param_files = ['acqus', 'audita.txt', 'log', 'procpar']


def init(location=""):
	"""
	Initialize a SpecDB database file. 
	
	Parameters
	----------
	+ location		file-system location to initialize a db
	
	Returns
	-------
	+ Bool			True if db was made, else False
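	
	Example
	-------
	A minimal sketch (hypothetical path; assumes the packaged schema file
	../sql/specdb.sql is present):
	
	>>> init(location='/tmp/specdb_demo.db')   # doctest: +SKIP
	True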
	"""
	
	assert(location)
	specdb_path = os.path.abspath(os.path.dirname(__file__))
	sql_path = os.path.join(specdb_path, '../sql/specdb.sql')
	
	conn = sqlite3.connect(location)
	c = conn.cursor()
	with open(sql_path) as fp:
		c.executescript(fp.read())
	conn.commit()
	conn.close()
	return True


def forms(table=None, num=None):
	"""
	Create a YAML form for SpecDB tables
	
	Parameters
	----------
	+ table		list, table names to generate forms for
	+ num		list, number of forms to make for each table
				default 1, if any number is given then len(table) == len(num)
	
	Returns
	-------
	form_out	String of YAML form
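	
	Example
	-------
	A minimal sketch (assumes the packaged schema file ../sql/specdb.sql
	is present):
	
	>>> print(forms(table=['user', 'project'], num=[1, 2]))   # doctest: +SKIP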
	"""
	
	if table is None:
		print("Argument `table` cannot be none")
		print("Aborting")
		sys.exit()
	
	if num is not None:
		if len(table) != len(num):
			print("the num argument must be same length as table")
			print("Aborting")
			sys.exit()
	
	# read schema
	specdb_path = os.path.abspath(os.path.dirname(__file__))
	sql_path = os.path.join(specdb_path, '../sql/specdb.sql')
	
	form_dic = ruamel.yaml.comments.CommentedMap()
	table_name = ""
	with open(sql_path, 'r') as fp:
		
		for line in fp.readlines():
			line = line.rstrip()
			
			if re.search(r'^CREATE ', line):
				tr = line.split()
				table_name = tr[2]
				
				if table_name in table:
					comment_split = line.split(' -- ')
					comment = comment_split[-1]
					form_dic[table_name] = ruamel.yaml.comments.CommentedMap()
					form_dic.yaml_add_eol_comment(comment, table_name, column=25)
			
			elif re.search(r' -- ', line):
				if table_name in table:
					sql_col  = line.split(' -- ')
					col_info = sql_col[0].split()
					comment  = sql_col[-1]
					
					column   = col_info[0]
					sql_type = col_info[1]
					
					form_dic[table_name][column] = SQL_TYPES[sql_type]
					form_dic[table_name].yaml_add_eol_comment(
						comment, column, column=35)
	
	yaml = ruamel.yaml.YAML()
	yaml.preserve_quotes = True
	
	if num is not None:
		num = [int(i) for i in num]
		form = ruamel.yaml.comments.CommentedMap()
		for i, tbl in enumerate(table):
			form[tbl] = ruamel.yaml.comments.CommentedMap()
			comment = form_dic.ca.items[tbl][2].value
			form.yaml_add_eol_comment(comment, tbl)
			
			for x in range(num[i]):
				form[tbl][x] = ruamel.yaml.comments.CommentedMap()
				for k,v in form_dic[tbl].ca.items.items():
					comment = v[2].value[1:]
					form[tbl][x][k] = form_dic[tbl][k]
					form[tbl][x].yaml_add_eol_comment(comment, k)
	
		string_stream = StringIO()
		yaml.dump(form, string_stream)
		form_out = string_stream.getvalue()
		string_stream.close()
		return form_out
	else:
		string_stream = StringIO()
		yaml.dump(form_dic, string_stream)
		form_out = string_stream.getvalue()
		string_stream.close()
		return form_out


def check_yaml(file=None):
	"""
	Check that the given YAML file follows the expected structure.
	
	Parameters
	----------
	+ file		path to YAML file to be checked
	
	Returns
	-------
	+ True		if YAML file passes check
	+ False		if YAML file fails check
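	
	Example
	-------
	A minimal sketch (hypothetical form path):
	
	>>> check_yaml(file='forms/user.yaml')   # doctest: +SKIP
	True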
	"""
	
	if file is None:
		print('no YAML file provided')
		print('must provide a YAML file to be checked')
		print('Aborting')
		sys.exit()
	
	# read yaml file to be checked
	with open(file, 'rt') as fp:
		yaml = ruamel.yaml.YAML()
		record = yaml.load(fp)
	
	# convert record to dictionary
	rec = json.loads(json.dumps(record))
	
	for tbl, dic in rec.items():
		
		template = forms(table=[tbl])
		temp = ruamel.yaml.safe_load(template)
		
		temp = json.loads(json.dumps(temp))
		
		rec_keys = list(rec[tbl].keys())
		
		if isinstance(rec[tbl][rec_keys[0]], dict):
		
			for ind, data in record[tbl].items():
				for key in data:
					if key not in temp[tbl]:
						print(f'unknown key {key} in {file}')
						print('Aborting')
						return False
		
		else:
			for key in dic:
				if key not in temp[tbl]:
					print(f'unknown key {key} in {file}')
					print('Aborting')
					return False
	return True


def find_uniq_constraint(table=None, cursor=None):
	"""
	Find the columns in specified table with a unique constraint on it
	
	Parameters
	----------
	+ table		which table in database to look at
	+ cursor	sqlite3 cursor object to perform the select
	
	Returns
	-------
	+ cols		returns the columns in table that have a unique constraint
				as a list
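	
	Example
	-------
	A minimal sketch (hypothetical database; the returned names depend on
	the schema):
	
	>>> conn = sqlite3.connect('specdb.db')   # doctest: +SKIP
	>>> find_uniq_constraint(table='user', cursor=conn.cursor())   # doctest: +SKIP
	['user_id']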
	"""
	sql = f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}'"
	cursor.execute(sql)
	schema = cursor.fetchone()
	entries = [ tmp.strip() for tmp in schema[0].splitlines()
		if tmp.find("UNIQUE")>=0 ]
	
	assert(len(entries) > 0)
	uniqs = entries[0][7:]       # strip the leading 'UNIQUE('
	first = uniqs.split(')')     # keep only what is inside the parens
	cols = first[0].split(', ')
	
	return cols


def condition_checker(columns=None, dic=None):
	"""
	Parameters
	----------
	+ columns		Name of SpecDB columns to query on
	+ dic			Data structure to check for consistency with SpecDB database
	
	Returns
	-------
	+ col_str		column to select on for the specific table requested
	+ condition		string that is condition for the SQLite SELECT statement
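	
	Example
	-------
	A minimal sketch with hypothetical column names:
	
	>>> condition_checker(
	...     columns=['user_id', 'project_id'],
	...     dic={'user_id': 'jdoe', 'project_id': 'P1'})
	('user_id, project_id', "user_id == 'jdoe' AND project_id == 'P1'")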
	"""
	
	if len(columns) == 1:
		col_str = columns[0]
	else:
		col_str = ', '.join(columns)
	
	condition = ''
	
	for i in range(len(columns)):
		
		condition += f"{columns[i]} == '{dic[columns[i]]}'"
		
		if i < len(columns)-1: condition += " AND "
		
	return col_str, condition


def read_acqus(path=None):
	"""
	Read an acquisition status (`acqus`) file to find the temperature,
	probe, and pulse sequence used in the experiment
	
	Parameters
	----------
	+ path		file path to acqus file
	
	Returns
	-------
	+ (temp, probe, pulse)	temperature float, probe string, and pulse
							sequence string; None if `path` is not a file
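	
	Example
	-------
	A minimal sketch (hypothetical path and values):
	
	>>> read_acqus(path='nmr_data/1/acqus')   # doctest: +SKIP
	(298.0, '5_mm_cryoprobe', 'zgesgp')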
	"""
	if not os.path.isfile(path): return None
	
	temp = 0.0
	probe = ''
	pulse = ''
	with open(path, 'r') as fp:
		for line in fp.readlines():
			line = line.rstrip()
			if '##$TE= ' in line:
				info = line.split('= ')
				temp = info[-1]
				temp = float(temp)
			
			if '##$PULPROG= ' in line:
				info  = line.split('= ')
				pulse = info[-1]
				pulse = pulse.replace('<','')
				pulse = pulse.replace('>','')
			
			if '##$PROBHD= ' in line:
				info  = line.split('= ')
				probe = info[-1]
				probe = probe.replace(' ','_')
				probe = probe.replace('<','')
				probe = probe.replace('>','')
				
	return temp, probe, pulse


def read_audita(path=None):
	"""
	Read a Bruker `audita.txt` file for when pulse sequence started
	
	Parameters
	----------
	+ path		path to `audita.txt` to parse 
	
	Returns
	-------
	+ time		string of when the pulse sequence started; None if `path`
				is not a file or no start time is found
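	
	Example
	-------
	A minimal sketch (hypothetical path and timestamp):
	
	>>> read_audita(path='nmr_data/1/audita.txt')   # doctest: +SKIP
	'2021-06-01 12:00:00.000'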
	"""
	
	if not os.path.isfile(path): return None
	
	# local named `start` to avoid shadowing the imported `time` module
	start = ''
	with open(path, 'r') as fp:
		for line in fp.readlines():
			line = line.rstrip()
			if 'started at ' in line:
				info = line.split()
				start = ' '.join((
					info[2],
					info[3]))
				return start


def insert_fid(files=None, cursor=None, dic=None):
	"""
	Insert time domain data into database
	
	Parameters
	----------
	+ files		list of files in FID directory
	+ cursor	SQLite cursor object to perform database operations
	+ dic		dictionary of information to insert into time_domain_datasets
				table
	
	Returns
	-------
	+ success	True if insert successful, False otherwise
	+ str/dic	the updated record dict on success, or the error string
				from table_inserter() on failure
	"""
	
	found = {}
	for f in files:
		if os.path.isfile(f):
			name = os.path.basename(f)
			
			if name in param_files: found[name] = f
	
	if 'acqus' in found and 'audita.txt' in found:
		temp, probe, pulse = read_acqus(path=found['acqus'])
		exp_time = read_audita(path=found['audita.txt'])
		
		dic['temperature_from_data'] = temp
		dic['probe_info']            = probe
		dic['pulse_sequence']        = pulse
		dic['experiment_date']       = exp_time
	
# 	if 'procpar' in found and 'log' in found:
# 		temp, probe, pulse = read_procpar(path=found['procpar'])
#	
	fid_dirs     = files[0].split('/')
	zipped_path  = '/'.join(fid_dirs[:-1])
	up_path = '/'.join(fid_dirs[:-2])
	cwd = os.getcwd()
	os.chdir(up_path)
	
	cmd =  f"find ./{dic['subdir_name']} -maxdepth 1"
	cmd += f" -type f -print0 | xargs -0"
	cmd += f" tar -zcvf specdb.tar.gz"
	os.system(cmd)
	
	with open('specdb.tar.gz', 'rb') as fp: fbytes = fp.read()
	
	dic['zipped_dir'] = fbytes
	time_domain = [f for f in os.listdir(zipped_path)
		if os.path.isfile(os.path.join(zipped_path, f))]
	
	time_domain = [f for f in time_domain if f == 'fid' or f == 'ser']
	assert(len(time_domain) == 1)
	
	with open(os.path.join(zipped_path, time_domain[0]), 'rb') as fp: 
		fbytes = fp.read()
	
	readable_hash = hashlib.md5(fbytes).hexdigest()
	dic['md5checksum'] = readable_hash
	
	success, last_row = table_inserter( # go ahead and insert
		table='time_domain_dataset',
		record=dic,
		cursor=cursor)
	
	os.remove('specdb.tar.gz')
	if success is True: return success, dic
	else:               return success, last_row


def insert_constructor(col_names, row):
	"""
	Construct the SQL insert statement for data to be inserted in a SpecDB
	table.
	
	Parameters
	----------
	+ col_names		names of columns for the specific table we are inserting
					into. these strings are used as keys into the dictionary
					`row`.
	+ row			key/value pairs to be inserted into the SQLite table.
	
	Returns
	-------
	+ cols			table column names to be inserted.
	+ VALS			`?` placeholders for the SQLite INSERT statement.
	+ vals			the actual values to be inserted into the table.
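	
	Example
	-------
	A minimal sketch with hypothetical columns; empty strings are stored
	as NULL:
	
	>>> insert_constructor(
	...     ['id', 'user_id', 'email'],
	...     {'user_id': 'jdoe', 'email': ''})
	('(user_id, email)', '(?, ?)', ['jdoe', None])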
	"""
	cols = []
	vals = []
	
	for c in col_names:
		# `id` is auto-generated; `isotopic_labeling` is handled separately
		if c == 'id' or c == 'isotopic_labeling': continue
		
		cols.append(c)
		
		# store empty strings as NULL so table constraints behave
		if type(row[c]) == str and len(row[c]) == 0:
			vals.append(None)
		else:
			vals.append(row[c])
	
	cols_str = '(' + ', '.join(cols) + ')'
	VALS = '(' + ', '.join(['?'] * len(cols)) + ')'
	
	return cols_str, VALS, vals


def table_inserter(table=None, record=None, cursor=None):
	"""
	Insert data into SQLite table from SpecDB
	
	Parameters
	----------
	+ table		name of table to insert into
	+ record	key/value pairs to be inserted into table
				keys are table column names, values are the data to be inserted
				for the respective column 
	+ cursor	sqlite3 cursor object
	
	Returns
	-------
	+ success	True if insert successful, False otherwise
	+ value		lastrowid of the inserted row on success, or the SQLite
				error message string on failure
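	
	Example
	-------
	A minimal sketch (hypothetical table and record; `record` must supply
	every column of the table other than `id`):
	
	>>> ok, value = table_inserter(
	...     table='user',
	...     record={'user_id': 'jdoe', 'first_name': 'Jane',
	...             'last_name': 'Doe'},
	...     cursor=conn.cursor())   # doctest: +SKIP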
	"""
	assert(table is not None)
	assert(record is not None)
	assert(cursor is not None)
	cursor.execute(f"select * from {table} limit 1")
	table_cols = [i[0] for i in cursor.description]
	if table == 'buffer_components':
		assert('isotopic_labeling' in table_cols)
		table_cols.remove('isotopic_labeling')
	
	columns, vals_place, values = insert_constructor(table_cols, record)
	
	# callers expect a (success, value) pair, so report an all-empty
	# record as a failed insert rather than returning None
	if values == [None] * len(values):
		return False, f"nothing to insert into table {table}"
	
	sql = f"INSERT INTO {table} {columns} VALUES{vals_place}"
	try:
		cursor.execute(sql, values) 
		return True, cursor.lastrowid
	except sqlite3.Error as e:
		return False, e.args[0]


def table_updater(table=None, record=None, cursor=None):
	"""
	Update data in a SQLite table from SpecDB
	
	Parameters
	----------
	+ table		name of table to update
	+ record	key/value pairs to be used to update table
	+ cursor	sqlite3 cursor object
	
	Returns
	-------
	+ success	True if update successful, False otherwise
	+ value		lastrowid of the updated row on success, or the SQLite
				error message string on failure
	"""
	assert(table is not None)
	assert(record is not None)
	assert(cursor is not None)
	
	cursor.execute(f"SELECT * from {table} limit 1")
	table_cols = [i[0] for i in cursor.description]
	
	columns, vals_place, values = insert_constructor(table_cols, record)
	
	# callers expect a (success, value) pair, so report an all-empty
	# record as a failed update rather than returning None
	if values == [None] * len(values):
		return False, f"nothing to update in table {table}"
	
	sql = f"INSERT OR REPLACE INTO {table} {columns} VALUES{vals_place}"
	try:
		cursor.execute(sql, values)
		return True, cursor.lastrowid
	except sqlite3.Error as e:
		return False, e.args[0]


def insert_logic(table=None, dic=None, write=False, cursor=None):
	"""
	Control logic for insertion into each of the tables
	
	Parameters
	----------
	+ table			which table to insert into
	+ dic			which data structure is given, read from JSON file
	+ write			whether to write new/update items in the database
	+ cursor		sqlite3 cursor object
	
	Returns
	-------
	+ status		True/False
					True if insertion was successful
					False otherwise
	+ value			The lastrowid of table after successful insert
					OR
					dictionary of updated record
					OR
					dictionary of errors to write to LOG file
	"""
	
	# if nothing is filled out for this table/row, abort
	all_none = []
	for v in dic.values():
		if type(v) == str:
			if len(v) == 0: all_none.append(True)
		elif type(v) == int:
			if v == 0: all_none.append(True)
		elif type(v) == float:
			if v == 0.0: all_none.append(True)
	
	if len(all_none) == len(dic):
		msg =  f"Table: {table} is completely empty"
		msg += "\nAborting"
		return False, msg
				
	# find the columns that serve as unique constraints
	uniq_cols = find_uniq_constraint(table=table, cursor=cursor)
	
	# determine if they are defined in dic
	uniq_pres = [True for uc in uniq_cols if len(dic[uc]) > 0]
	if len(uniq_pres) == len(uniq_cols):
		# if they are defined, check if the remaining columns are empty
		# if they are, assume that the user wanted to pull in the info
		# for that table entity
		# otherwise, insert as normal
		
		col_names, cond = condition_checker( # build strings for SELECT command
			columns=uniq_cols,
			dic=dic
		)
		
		# checking if all but uniq columns are empty
		empty = True
		for k,v in dic.items():
			if k in uniq_cols: continue
			if type(v) == str:
				if len(v) > 0:
					empty = False
					break
			elif type(v) == int:
				if v != 0:
					empty = False
					break
			elif type(v) == float:
				if v != 0.0:
					empty = False
					break
		
		if empty: # attempt to fill JSON file with database info
			sql = f"SELECT {col_names} FROM {table} WHERE {cond}"
			cursor.execute(sql)
			results = cursor.fetchall()
			if len(results) == 0: # uniq col values don't exist and empty
				
				status = False
				msg =  f"nothing to insert on table {table}"
				msg += f"\nwhat was given:\n"
				msg += f"{json.dumps(dic)}"
				msg += "\n"
				
				return status, msg
				
			else: # they are in the database and rest are empty
				assert(len(results) == 1)
				
				sql = f"SELECT * FROM {table} WHERE {cond}"
				cursor.execute(sql)
				results = cursor.fetchall()
				assert(len(results) == 1)
				
				# set the JSON data to what is in db
				tmp_results = dict(results[0])
				del tmp_results['id']
				for k,v in tmp_results.items():
					if v is None: tmp_results[k] = ""
				
				dic = tmp_results
				return True, dic
		
		else: # not empty, user supplied some information, attempt to insert
			sql = f"SELECT * FROM {table} WHERE {cond}"
			cursor.execute(sql)
			results = cursor.fetchall()
			
			if len(results) == 0: # supplied info not in database
				if not write:
					msg =  "information provided is new\n"
					msg += json.dumps(dic, indent='\t')
					msg += "\nMust set --write to insert provided information"
					msg += " to database\nAborting"
					return False, msg
				
				else:
					success, last_row = table_inserter( # go ahead and insert
						table=table,
						record=dic,
						cursor=cursor)
					
					if type(last_row) is str:
						assert(success is False)
						
						msg =  f"SQLite error on insert on table {table}"
						msg += "\nErr Message:\n"
						msg += f"{last_row}"
						msg += "\ncheck the template form for instructions and "
						msg += "examples\n"
						msg += "ensure all ids this table relates to (i.e "
						msg += "constructs relate to targets) are inserter "
						msg += "already\n"
						msg += "Aborting"
						
						return success, msg
					
					return success, last_row
				
			else: # table unique columns actually in db, make sure they match
				assert(len(results) == 1)
				tmp_results = dict(results[0])
				del tmp_results['id']
				for k,v in dic.items():
					if type(v) == str:
						if len(v) == 0: dic[k] = None
				
				if tmp_results != dic:
					
					if not write:
						success = False
						msg =  f"requested data is different than the database"
						msg += "\ninput information:\n"
						msg += json.dumps(dic, indent='\t')
						msg += "\ndatabase information:\n"
						msg += json.dumps(tmp_results, indent='\t')
						msg += "\nAborting"
					
						return success, msg
						
					else:
						success, last_row = table_updater(
							table=table,
							record=dic,
							cursor=cursor
						)
						
						if type(last_row) is str:
							assert(success is False)
							
							msg =  f"SQLite error on insert on table {table}"
							msg += "\nErr Message:\n"
							msg += f"{last_row}"
							msg += "\ncheck the template form for examples and "
							msg += "instructions\n"
							msg += "ensure all ids this table relates to (i.e "
							msg += "constructs relate to targets) are inserted "
							msg += "already\n"
							msg += "Aborting"
						
							return success, msg
						
						return success, last_row
				
				else: return True, results[0]['id']
	else:
		
		msg =  f"unique column ids {uniq_cols} cannot be blank"
		msg += "\nmust provide some string for them"
		msg += "\nAborting"
		
		return False, msg


def insert(file=None, db=None, write=False):
	"""
	Add a single YAML file into SpecDB SQLite database.
	Could be used for any data item type in database.
	
	Parameters
	----------
	+ file			path to YAML file to be inserted
	+ db			path to specdb database file
	+ write		    whether to write new SpecDB records into the SpecDB
					database, or to update records in database
					Expires after each new write
	
	Returns
	-------
	True			If insertion successful
	False			If insertion is unsuccessful
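	
	Example
	-------
	A minimal sketch (hypothetical paths):
	
	>>> insert(file='session/form.yaml', db='specdb.db', write=True)   # doctest: +SKIP
	True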
	"""
	if file is None:
		print('no YAML file provided')
		print('must provide a YAML file to insert for specdb add')
		print('Aborting')
		sys.exit()
	
	if db is None:
		print('no specdb database file provided')
		print('must provide a .db file for insert to work')
		print('Aborting')
		sys.exit()
	
	# connect to database
	conn = sqlite3.connect(os.path.abspath(db))
	conn.row_factory = sqlite3.Row
	c = conn.cursor()
	
	# enforce foreign key constraints
	sql = "PRAGMA foreign_keys = ON"
	c.execute(sql)
	
	# check yaml file to be inserted that it has the expected keys
	if not check_yaml(file=file):
		print(f"JSON file {file} does not have expected keys")
		print("Aborting")
		sys.exit()
	
	# read yaml file to be inserted
	with open(file, 'rt') as fp:
		yaml = ruamel.yaml.YAML()
		record = yaml.load(fp)
	
	# convert record to plain dictionary
	record = json.loads(json.dumps(record))
	
	# insert data from yaml in the specific table order
	for tbl in table_order:
		# some information in the YAML form is entered as a single dictionary
		# others require a list of dictionaries
		
		# not all tables need to be inserted at the same time
		# for example user information can be inserted separately
		if tbl not in record: continue
		
		rec_keys = list(record[tbl].keys())
		
		if isinstance(record[tbl][rec_keys[0]], dict):
		
			for ind, data in record[tbl].items():
					
				status, value = insert_logic(
					table=tbl,
					dic=data,
					write=write,
					cursor=c
				)
				
				if status is True:
					if type(value) is dict:
						record[tbl][ind] = value
						conn.commit()
					elif type(value) is int:
						conn.commit()
						continue
					else:
						print("Unexpected type")
						print("Aborting")
						return False
			
				elif status is False:
					if value == 0:
						return False
				
					assert(type(value) is not dict)
				
					print(value)
					return False
			
				else:
					print(f"Unexpected value {status}")
					print("Aborting")
					return False
			
		else:
			
			status, value = insert_logic(
				table=tbl,
				dic=record[tbl],
				write=write,
				cursor=c
			)
			
			if status is True:
				if type(value) is dict:
					conn.commit()
					record[tbl] = value
				elif type(value) is int:
					conn.commit()
					continue
				else:
					print("Unexpected type")
					print("Aborting")
					return False
		
			elif status is False:
				if value == 0:
					print(value)
					return False
			
				assert(type(value) is not dict)
			
				print(value)
				return False
		
			else:
				print(f"Unexpected value {status}")
				print("Aborting")
				return False
	
	# now insert sessions if present
	if 'session' in record:
		
		file_path = os.path.abspath(file)
		dirs = file_path.split('/')
		session_folder = dirs[-2]
		#print(session_folder)
		session_path = '/'.join(dirs[:-1])
		
		rec_keys = list(record['session'].keys())
		
		if isinstance(record['session'][rec_keys[0]], dict):
			record['session']['0']['folder_name'] = session_folder
			session_dic = record['session']['0']
			
		else:
			record['session']['folder_name'] = session_folder
			session_dic = record['session']
			
		status, value = insert_logic(
			table='session',
			dic=session_dic,
			cursor=c,
			write=write
		)
		
		if status is True:
			if type(value) is dict:
				# keep the working copy and the YAML record in sync,
				# whether the session was given flat or indexed
				session_dic = value
				if isinstance(record['session'][rec_keys[0]], dict):
					record['session']['0'] = value
				else:
					record['session'] = value
				conn.commit()
			
			elif type(value) is int:
				conn.commit()
			
			else:
				print("Unexpected type")
				print("Aborting")
				return False
			
		elif status is False:
			if value == 0: return False
			
			assert(type(value) is not dict)
			
			print(value)
			return False
		
		else:
			print(f"Unexpected value {status}")
			print("Aborting")
			return False
		
		session_id = value
		if 'time_domain_dataset' not in record:
			
			record['time_domain_dataset'] = dict()
			for sub in os.listdir(session_path):
			
				if os.path.isdir(os.path.join(session_path, sub)):
					
					subpath = os.path.join(session_path, sub)
					if 'fid' in os.listdir(subpath) or 'ser' in os.listdir(subpath):
						
						new_fid = {
							'subdir_name': sub,
							'pulse_sequence_nickname': None,
							'probe_id': None,
							'session_id': session_id,
							'zipped_dir': None,
							'md5checksum': None,
							'temperature_from_data': None,
							'experiment_date': None,
							'pulse_sequence': None,
							'probe_info': None,
							'pst_id': session_dic['pst_id']
						}
						
						files = [
							os.path.join(subpath, f) for f in os.listdir(
								subpath)
						]
						
						status, value = insert_fid(
							files=files,
							dic=new_fid,
							cursor=c)
						
						if status is True:
							assert(type(value) is dict)
							del value['zipped_dir']
							record['time_domain_dataset'][
								value['subdir_name']] = value
							conn.commit()
							#print(status, sub)
						
						elif status is False:
							print("Err message")
							print(value)
							print("moving to next fid")
							continue
						
						else:
							print("Unexpected type")
							print("Aborting")
							return False
		else:
			rec_keys = list(record['time_domain_dataset'].keys())
			if isinstance(record['time_domain_dataset'][rec_keys[0]], dict):
				for ind, data in record['time_domain_dataset'].items():
					new_fid = {
						'subdir_name': None,
						'pulse_sequence_nickname': None,
						'probe_id': None,
						'session_id': session_id,
						'zipped_dir': None,
						'md5checksum': None,
						'temperature_from_data': None,
						'experiment_date': None,
						'pulse_sequence': None,
						'probe_info': None,
						'pst_id': session_dic['pst_id']
					}
					
					for k, v in data.items(): new_fid[k] = v
				
					files = [
						os.path.join(
							session_path,
							new_fid['subdir_name'],
							f) for f in os.listdir(
								os.path.join(
									session_path,
									new_fid['subdir_name']))]
					
					status, value = insert_fid(
						files=files,
						dic=new_fid,
						cursor=c)
						
					if status is True:
						assert(type(value) is dict)
						del value['zipped_dir']
						record['time_domain_dataset'][ind] = value
						conn.commit()
						
					elif status is False:
						print("Err message")
						print(value)
						print("moving to next fid")
						continue
						
					else:
						print("Unexpected type")
						print("Aborting")
						return False
	
		conn.commit()
		conn.close()
		print(f"Inserted data from form file at {file}")
		
		with open(file, 'w') as fp:
			yaml = ruamel.yaml.YAML()
			yaml.dump(record, fp)
		
		return True
	
	else:
		
		conn.commit()
		conn.close()
		print(f"Inserted data from json file at {file}")
		
		with open(file, 'w') as fp:
			yaml = ruamel.yaml.YAML()
			yaml.dump(record, fp)
		
		return True


def summary(db=None, table=None):
	"""
	Display contents of a SpecDB table
	
	Parameters
	----------
	+ db		path to a specdb database file
	+ table		str, name of table to display contents of
	
	Returns
	-------
	prints to stdout a dataframe of the SpecDB table contents, ordered by
	the table's unique-constraint columns
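	
	Example
	-------
	A minimal sketch (hypothetical database path):
	
	>>> summary(db='specdb.db', table='user')   # doctest: +SKIP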
	"""
	if db is None:
		print("Must provide a database file")
		print("Aborting")
		sys.exit()
	
	if table is None:
		print("Must provide a table to summarize")
		print("Aborting")
		sys.exit()
	
	pd.set_option('display.max_columns', None)
	pd.set_option('display.max_colwidth', 20)
	pd.set_option('display.max_rows', None)
	pd.set_option('display.width', 110)
	
	# connect to database
	conn = sqlite3.connect(os.path.abspath(db))
	conn.row_factory = sqlite3.Row
	c = conn.cursor()
	
	# collect the columns that define unique constraint for the table
	if table == 'summary':
		sql = f"SELECT * FROM {table}"
	else:
		uniq_cols = find_uniq_constraint(table=table, cursor=c)
		
		# perform sql select * query from provided table
		if len(uniq_cols) == 1:
			sql = f"SELECT * FROM {table} ORDER BY {uniq_cols[0]} ASC"
		elif len(uniq_cols) == 2:
			sql = f"SELECT * FROM {table} ORDER BY {uniq_cols[0]} ASC,"
			sql += f" {uniq_cols[1]} ASC"
		else:
			sql = f"SELECT * from {table}"
	
	df = pd.read_sql(sql, conn)
	blankindex = [''] * len(df)
	df.index = blankindex
	df = df.rename(columns=lambda x: x[:14])
	
	print()
	print(df)
	print()
	
	conn.close()
	return


def flatten_query(sql, col, table, cursor):
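	"""
	Execute a SELECT and flatten the result rows into a dictionary keyed
	by (table, column) tuples, each holding the list of values seen for
	that column.
	
	Parameters
	----------
	+ sql		SQL SELECT statement to execute
	+ col		column name to exclude from the flattened result
	+ table		table name used as the first element of each key
	+ cursor	sqlite3 cursor (connection row_factory must be sqlite3.Row)
	
	Returns
	-------
	+ flat		dict mapping (table, column) -> list of column values
	"""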
	cursor.execute(sql)
	results = cursor.fetchall()
	flat = dict()
	
	for rr in results:
		for k in rr.keys():
			if k == 'id': continue
			if k == col: continue
			if (table, k) not in flat: flat[(table, k)] = list()
			flat[(table, k)].append(rr[k])
	
	return flat


def query_dbview(view, curr_view, table, fkeys, cursor):
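	"""
	Recursively follow foreign keys from `table` up through its parent
	tables, merging each parent's flattened rows into `view`.
	
	Parameters
	----------
	+ view		accumulated dict of (table, column) -> list of values
	+ curr_view	flattened rows for the current table (from flatten_query)
	+ table		name of the current table
	+ fkeys		foreign-key map of the form
				fkeys[child_table][child_column] =
				{'parent': parent_column, 'ptable': parent_table}
	+ cursor	sqlite3 cursor object
	
	Returns
	-------
	+ view		the merged view dictionary
	"""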
	for k, v in curr_view.items():
		if table in fkeys:
			if k[1] in fkeys[table]:
				view[k] = v
				parent_key = fkeys[table][k[1]]['parent']
				parent_table = fkeys[table][k[1]]['ptable']
				assert(type(v) == list and len(v) == 1)
				sql = f"select * from {parent_table} where {parent_key} = '{v[0]}' order by id ASC"
				fq = flatten_query(sql, '', parent_table, cursor)
				for k1 in fq:
					if k1 not in view: view[k1] = fq[k1]
				view = query_dbview(view, fq, parent_table, fkeys, cursor)
				del fkeys[table][k[1]]
			else:
				view[k] = v
		else:
			view[k] = v
	return view


def star_loop_constructor(tags, view, cat):
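	"""
	Build the data dictionary for a STAR loop from the flattened view.
	
	Parameters
	----------
	+ tags		tag names present in the STAR loop
	+ view		dict of saveframe category -> {tag: [values]}
	+ cat		saveframe category to pull values from
	
	Returns
	-------
	+ loop_data	dict of tag -> list of values, padded with '.' to equal
				length, with angle brackets stripped from string values
	"""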
	
	if cat not in view: return {}
	
	# find the longest column so shorter ones can be padded
	maxlen = -1
	for t in view[cat]:
		if len(view[cat][t]) > maxlen:
			maxlen = len(view[cat][t])
	
	# pad short columns with '.' (the STAR null placeholder)
	for col in view[cat]:
		if len(view[cat][col]) < maxlen:
			view[cat][col] += ['.'] * (maxlen - len(view[cat][col]))
	
	loop_data = {t:view[cat][t] for t in tags if t in view[cat]}
	
	# strip the angle brackets Bruker wraps around parameter values
	for k in loop_data:
		for i, val in enumerate(loop_data[k]):
			if type(val) != str: continue
			loop_data[k][i] = val.replace('<','').replace('>','')
	
	return loop_data


def fid2star(data=None, cursor=None, save=None):
	"""
	Convert FID information from SpecDB into an NMR-STAR file
	
	Parameters
	----------
	+ data		SQLite row object from SpecDB summary query
	+ cursor	SQLite cursor object to perform queries
	+ save		path to save resulting STAR file
	
	Returns
	-------
	True		if STAR creation successful
	"""
	
	sql =  "SELECT m.name, p.* FROM sqlite_master m JOIN"
	sql += " pragma_foreign_key_list(m.name) p ON m.name != p.'table' WHERE"
	sql += " m.type = 'table' ORDER BY m.name"
	cursor.execute(sql)
	all_fks = cursor.fetchall()
	
	all_fkeys = dict()
	for fk in all_fks:
		child  = fk[4]   # child column ('from')
		parent = fk[5]   # parent column ('to')
		ptable = fk[3]   # parent table
		ctable = fk[0]   # child table (m.name)
		
		if ctable not in all_fkeys:
			all_fkeys[ctable] = dict()
		
		if child not in all_fkeys[ctable]:
			all_fkeys[ctable][child] = {'parent':parent, 'ptable':ptable}
			
	
	# gather conversion table
	cursor.execute("SELECT * from star_conversion")
	mappings = cursor.fetchall()
	
	translator = dict()
	for k in mappings:
		table = k[2]
		col   = k[1]
		tag   = k[3]
		frame = k[4]
		
		if table not in translator: translator[table] = dict()
		if col not in translator[table]: translator[table][col] = dict()
		if frame not in translator[table][col]: translator[table][col][frame] = []
		translator[table][col][frame].append(tag)
	
	sql = f"SELECT * FROM time_domain_dataset where id == '{data['id']}'"
	cursor.execute(sql)
	results = cursor.fetchall()
	
	
	view = dict()
	flat = flatten_query(sql, '', 'time_domain_dataset', cursor)
	dbview = query_dbview(view, flat, 'time_domain_dataset', all_fkeys, cursor)
	
	sql = f"SELECT * FROM batch_components WHERE pst_id == '{dbview[('time_domain_dataset','pst_id')][0]}'"
	flat = flatten_query(sql, '', 'batch_components', cursor)
	moreview = query_dbview(dbview, flat, 'batch_components', all_fkeys, cursor)
	
	
	sql = f"SELECT * from buffer_components WHERE buffer_id == '{dbview[('pst','buffer_id')][0]}'"
	flat = flatten_query(sql, 'buffer_id', 'buffer_components', cursor)
	fullview = query_dbview(moreview, flat, 'buffer_components', all_fkeys, cursor)
	
	
	specdb_path = os.path.abspath(os.path.dirname(__file__))
	star_path = os.path.join(specdb_path, '../sql/template.str')
	
	# Fill up the STAR file
	# Load the template
	with open(star_path, 'r') as fp:
		entry = pynmrstar.Entry.from_file(fp)
	
	entry.entry_id = 'SpecDBQuery'
	# Obtain schema version
	schema = utils.get_schema()
	fullview['nmr_star_version'] = schema.version

	# Fill in timedomain specific tags not given in schema
	fullview['title'] = fullview[('project', 'project_id')][0]
	fullview['data_file_content_type'] = 'timedomain_data'
	
	"""
	_Experiment_file.Name
	_Experiment_file.Directory_path
	_Upload_data.Data_file_name
	_Upload_data.Data_file_content_type
	
	"""
	#print(save)
	save_split = save.split('/')
	save_path = '/'.join(save_split[:-1])
	if 'fid' in os.listdir(save_path):
		tdname = 'fid'
	elif 'ser' in os.listdir(save_path):
		tdname = 'ser'
	else:
		print(f"no fid or ser file found in {save_path}")
		print("skipping STAR creation for this dataset")
		return None
	
	file_dir_path = '/'.join(save_split[-3:-1])+'/'+tdname
	
	specdb = dict()
	for k in fullview.keys():
		if k[0] in translator:
			if k[1] in translator[k[0]]:
				for frame in translator[k[0]][k[1]]:
					if frame not in specdb: specdb[frame] = dict()
					for tag in translator[k[0]][k[1]][frame]:
						if tag not in specdb[frame]: specdb[frame][tag] = list()
						specdb[frame][tag].extend(fullview[k])
	
	# store these as single-element lists so the `[0]` lookups below pull
	# out the value rather than its first character
	specdb['nmr_star_version'] = [schema.version]
	specdb['title'] = [fullview[('project', 'project_id')][0]]
	specdb.setdefault('experiment_list', dict())
	specdb['experiment_list']['_Experiment_file.Name'] = [tdname]
	specdb['experiment_list']['_Experiment_file.Type'] = ['free-induction decay']
	specdb['experiment_list']['_Experiment_file.Directory_path'] = [file_dir_path]
	
	for frame in entry:
		cat = frame._category
		
		for ftags in frame.tag_dict.keys():
			
			tt = ftags.capitalize()
			full_tag = frame.tag_prefix+'.'+tt
			
			if ftags in specdb:
				frame[ftags] = specdb[ftags][0]
			
			if cat in specdb:
				if full_tag in specdb[cat]:
					frame[ftags] = specdb[cat][full_tag][0]
			
		if len(list(frame.loop_iterator())) > 0:
			for loop in frame.loop_iterator():
				tags = loop.get_tag_names()
				
				if '_Upload_data.Data_file_name' in tags:
					data_to_add = {
						'_Upload_data.Data_file_name':[tdname],
						'_Upload_data.Data_file_content_type':['free-induction decay']
					}
					loop.add_data(data_to_add)
					continue
				
				data_to_add = star_loop_constructor(tags, specdb, cat)
				
				if not data_to_add: continue
				loop.add_data(data_to_add)
	
	entry.add_missing_tags()
	#entry.normalize()
	entry.write_to_file(save)
	print("done")
	
	return True
	

def query(db=None, sql=None, star=False, output_dir=None):
	"""
	Perform a query against summary view of database
	Format the query results either in a directory hierarchy or in NMR-STAR
	for each retrieved FID
	
	Parameters
	----------
	+ db			path to a specdb database file
	+ sql			user provided SQL select statement, must be on summary
	+ star			whether to output a STAR file for each FID in the query
	+ output_dir	where to put the results
	
	Returns
	-------
	Writes a directory hierarchy of results to `output_dir`, with an
	NMR-STAR file for each retrieved FID when `star` is True
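	
	Example
	-------
	A minimal sketch (hypothetical database, pst_id, and output path):
	
	>>> query(
	...     db='specdb.db',
	...     sql="SELECT * FROM summary WHERE pst_id = 'PST1'",
	...     star=True,
	...     output_dir='results')   # doctest: +SKIP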
	"""
	
	if db is None:
		print("Must provide a database file")
		print("Aborting")
		sys.exit()
	
	if sql is None:
		print("Must provide a SQL SELECT statement to execute")
		print("Aborting")
		sys.exit()
	
	if output_dir is None:
		print("Must provide an output directory to place query results")
		print("Aborting")
		sys.exit()
	
	if not re.search(r'^SELECT', sql, re.IGNORECASE):
		print("provided SQL does not look like a SELECT statement")
		print(f"provided SQL: {sql}")
		print("Aborting")
		sys.exit()
	
	if not re.search(r' summary ', sql):
		print("provided SQL does not appear to SELECT on summary")
		print(f" provided SQL: {sql}")
		print("Aborting")
		sys.exit()
	
	# connect to database
	conn = sqlite3.connect(db)
	conn.row_factory = sqlite3.Row
	c = conn.cursor()
	
	# execute query
	c.execute(sql)
	
	results = c.fetchall()
	results = [dict(res) for res in results]
	
	if not os.path.isdir(os.path.abspath(output_dir)):
		os.mkdir(os.path.abspath(output_dir))
	
	output_dir = os.path.abspath(output_dir)
	
	for res in results:
		date_info = res['experiment_date'].split()
		date = date_info[0].split('-')
		date = ''.join(date)
		ids = [res['pst_preparer'], res['pst_id'], date]
		
		dir_name = '_'.join(ids)
		
		if not os.path.isdir(os.path.join(output_dir, dir_name)):
			os.mkdir(os.path.join(output_dir, dir_name))
		
		f = io.BytesIO()
		tar_file = tarfile.open(fileobj=f, mode='w:')
		tar_info = tarfile.TarInfo('data_zipped')
		tar_info.size = len(res['zipped_dir'])
		tar_info.mtime = time.time()
		tar_file.addfile(tar_info, io.BytesIO(res['zipped_dir']))
		tar_file.close()
		os.chdir(os.path.join(output_dir, dir_name))
		
		with open('stream.tar.gz', 'wb') as fp:
			fp.write(f.getvalue())
		
		os.system("tar -xf stream.tar.gz")
		os.system("tar -xf data_zipped")
		
		os.remove("stream.tar.gz")
		os.remove("data_zipped")
		print(os.getcwd())
		
		sql =  f"SELECT subdir_name FROM time_domain_dataset WHERE"
		sql += f" id='{res['id']}'"
		
		c.execute(sql)
		subname = c.fetchall()[0]
		if star:
			star_path = os.path.join(
				output_dir, dir_name, subname['subdir_name'], 'fid.star')
			fid2star(
				data=res,
				cursor=c,
				save=star_path
			)


def backup(db=None, object_dir='objects', backup_file='backup.txt'):
	"""
	This function performs the incremental backup 
	This function is taken with slight modifications from the following Github
	repository: https://github.com/nokibsarkar/sqlite3-incremental-backup.git
	
	For this function, we implemented the python version of the backup at 
	sqlite3-incremental-backup/python/sqlite3backup/python 
	
	Much credit to Github user nokibsarkar
	
	Parameters
	----------
	+ db				path to sqlite database to backup
	+ object_dir		path to objects directory where all pages will reside
	+ backup_file		file to save sha256 hashes of pages
	
	Returns
	-------
	True	if backup successful
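	
	Example
	-------
	A minimal sketch (hypothetical paths):
	
	>>> backup(db='specdb.db', object_dir='objects',
	...     backup_file='backup.txt')   # doctest: +SKIP
	True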
	"""
	
	page_size = 0
	# Open the database.
	with open(db, "rb") as db_file_object:
		assert(
			db_file_object.read(SQLITE_HEADER_LENGTH) == b"SQLite format 3\x00")
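		
		# the 2-byte big-endian page-size field always has a zero low byte
		# (or encodes 65536 as the value 1), so reading it little-endian
		# and scaling by 256 recovers the true page size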
		db_file_object.seek(SQLITE_PAGE_SIZE_INDEX, os.SEEK_SET)
		page_size = int.from_bytes(db_file_object.read(2), 'little') * 256
		db_file_object.seek(SQLITE_PAGE_COUNT_INDEX, os.SEEK_SET)
		page_count = int.from_bytes(db_file_object.read(4), 'big')
	
	pages = []
	with open(db, "rb") as db_file_object:
		for page_number in range(page_count):
			db_file_object.seek(page_number * page_size, os.SEEK_SET)
			page = db_file_object.read(page_size)
			page_hash = sha256(page).hexdigest()   # avoid shadowing builtin hash()
			directory, filename = page_hash[:2], page_hash[2:]
			file_path = os.path.join(object_dir, directory, filename)
			if not os.path.exists(file_path):
				os.makedirs(os.path.dirname(file_path), exist_ok=True)
				with open(file_path, "wb") as file_object:
					file_object.write(page)
			pages.append(page_hash)
	
	# Write the page hashes to the backup file.
	with open(backup_file, 'w') as fp:
		fp.write('\n'.join(pages))

	return True


def restore(backup=None, backup_file=None, object_dir=None):
	"""
	This function performs the restore function from an incremental backup 
	This function is taken with slight modifications from the following Github
	repository: https://github.com/nokibsarkar/sqlite3-incremental-backup.git
	
	For this function, we implemented the python version of the backup at 
	sqlite3-incremental-backup/python/sqlite3backup/python 
	
	Much credit to Github user nokibsarkar
	
	Parameters
	----------
	+ backup			path to sqlite database to restore to
	+ object_dir		path to objects directory where all pages will reside
	+ backup_file 		file containing the sha256 hashes of pages, as
						written by backup()
	
	Returns
	-------
	True	if restore successful
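	
	Example
	-------
	A minimal sketch (hypothetical paths; `objects` and `backup.txt` come
	from a previous backup() call):
	
	>>> restore(backup='restored.db', backup_file='backup.txt',
	...     object_dir='objects')   # doctest: +SKIP
	True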
	"""
	
	# Read the pages from the backup file
	with open(backup_file, 'r') as fp:
		pages = fp.read().split('\n')
	
	# Open the database.
	with open(backup, "wb") as db_file_object:
		# Iterate through the pages and write them to the database.
		for page in pages:
			path = os.path.join(object_dir, page[:2], page[2:])
			with open(path, "rb") as file_object:
				db_file_object.write(file_object.read())
	
	# Restoration is complete
	return True
